import ufl
import numpy

from mpi4py import MPI
from petsc4py import PETSc

from dolfinx import mesh, fem, io, nls, log
from dolfinx.fem.petsc import NonlinearProblem
from dolfinx.nls.petsc import NewtonSolver


def q(u):
    """Return the solution-dependent coefficient ``1 + u**2``.

    ``u`` can be any object that supports ``**`` with an int exponent and
    addition to an int; in this script it is called with UFL expressions.
    """
    squared = u**2
    return 1 + squared


# Unit-square mesh created on the world communicator, with 5 divisions
# requested in each coordinate direction.
domain = mesh.create_unit_square(MPI.COMM_WORLD, 5, 5)
x = ufl.SpatialCoordinate(domain)

# Manufactured solution as a symbolic UFL expression. The source term f is
# derived from it as f = -div(q(u) grad(u)).
u_ufl = 1 + x[0] + 2 * x[1]
exact_flux = q(u_ufl) * ufl.grad(u_ufl)
f = -ufl.div(exact_flux)

# Degree-1 Lagrange function space on the mesh.
V = fem.functionspace(domain, ("Lagrange", 1))
def u_exact(x):
    """Evaluate the manufactured solution ``1 + x[0] + 2 * x[1]`` numerically.

    A plain Python callable is needed (in addition to the symbolic
    ``u_ufl``) because interpolation passes numeric coordinate data rather
    than UFL objects. A lambda would work equally well.

    Args:
        x: Indexable coordinate data where ``x[0]`` and ``x[1]`` are the
            first and second coordinates (scalars or arrays). Presumably the
            ``(3, num_points)`` array handed over by ``interpolate``.

    Returns:
        The solution value(s) at the given coordinates.
    """
    # The formula is written out instead of using ``eval(str(u_ufl))``: that
    # trick only works while UFL's string form happens to be valid Python
    # and this parameter is named ``x``.
    # NOTE: keep this in sync with the symbolic definition of ``u_ufl``.
    return 1 + x[0] + 2 * x[1]


# Dirichlet data: the exact solution interpolated into V.
u_D = fem.Function(V)
u_D.interpolate(u_exact)

# Facets have topological dimension one lower than the mesh itself.
fdim = domain.topology.dim - 1

# The marker returns True for every column of the coordinate array it is
# given, so no boundary facet is filtered out.
boundary_facets = mesh.locate_entities_boundary(
    domain, fdim, lambda x: numpy.ones(x.shape[1], dtype=bool))
boundary_dofs = fem.locate_dofs_topological(V, fdim, boundary_facets)
bc = fem.dirichletbc(u_D, boundary_dofs)

# The problem is nonlinear: the coefficient q(u) depends on the unknown itself.
# The unknown is therefore a fem.Function (the object in the finite element
# space that stores the numerical solution), not a ufl.TrialFunction.
uh = fem.Function(V)
v = ufl.TestFunction(V)

# Residual form: diffusion term minus source term.
diffusion_term = q(uh) * ufl.dot(ufl.grad(uh), ufl.grad(v)) * ufl.dx
source_term = f * v * ufl.dx
F = diffusion_term - source_term

# Newton iteration on the residual F with unknown uh.
problem = NonlinearProblem(F, uh, bcs=[bc])
solver = NewtonSolver(MPI.COMM_WORLD, problem)
# Judge convergence by the solution increment (change between iterations).
solver.convergence_criterion = "incremental"
solver.rtol = 1e-6  # relative tolerance
solver.report = True

# Configure the Krylov solver owned by the Newton solver through the PETSc
# options database: GMRES with a hypre BoomerAMG preconditioner.
ksp = solver.krylov_solver
opts = PETSc.Options()
option_prefix = ksp.getOptionsPrefix()
ksp_settings = {
    "ksp_type": "gmres",
    "ksp_rtol": 1.0e-8,
    "pc_type": "hypre",
    "pc_hypre_type": "boomeramg",
    "pc_hypre_boomeramg_max_iter": 1,
    "pc_hypre_boomeramg_cycle_type": "v",
}
for setting_name, setting_value in ksp_settings.items():
    # Every key is namespaced with the solver's own options prefix.
    opts[f"{option_prefix}{setting_name}"] = setting_value
ksp.setFromOptions()

# Emit INFO-level log messages while solving.
log.set_log_level(log.LogLevel.INFO)
n, converged = solver.solve(uh)
# Raise explicitly instead of using ``assert``: assertions are stripped under
# ``python -O``, which would let a non-converged solve pass unnoticed.
if not converged:
    raise RuntimeError(
        f"Newton solver did not converge after {n:d} iterations")
# Report once from rank 0 only, matching the rank-0 guarded error output
# later in this script.
if domain.comm.rank == 0:
    print(f"Number of iterations: {n:d}")

# L2 error, measured against the exact solution interpolated into a
# degree-2 Lagrange space.
V_ex = fem.functionspace(domain, ("Lagrange", 2))
u_ex = fem.Function(V_ex)
u_ex.interpolate(u_exact)

squared_error_form = fem.form((uh - u_ex)**2 * ufl.dx)
error_local = fem.assemble_scalar(squared_error_form)
# Sum the per-process contributions before taking the square root.
error_global = domain.comm.allreduce(error_local, op=MPI.SUM)
error_L2 = numpy.sqrt(error_global)
if domain.comm.rank == 0:
    print(f"L2-error: {error_L2:.2e}")

# Largest absolute difference between the coefficient arrays of uh and u_D
# (both live in V), reduced to a maximum over all processes.
nodal_gap = numpy.abs(uh.x.array - u_D.x.array)
error_max = domain.comm.allreduce(numpy.max(nodal_gap), op=MPI.MAX)
if domain.comm.rank == 0:
    print(f"Error_max: {error_max:.2e}")
